Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix complex arguments and implement float16 lowering #2403

Open
wants to merge 18 commits into
base: complex
Choose a base branch
from

Conversation

enwask
Copy link
Contributor

@enwask enwask commented Jul 11, 2024

Fixes some issues that prevented complex float arguments from being passed correctly. Also implements lowering for half-precision floats with the _Float16 type.

Some remaining tasks:

  • MPI with complex floating types
  • Look into MPI with float16, if possible
  • Type conversion is done by modifying the global ctypes_vector_mappet; need some more tests for operator building/rebuilding/pickling with float16, or potentially a different (hopefully less brittle) solution
  • More test coverage of existing behavior applied to float16 types

Copy link
Contributor

@mloubout mloubout left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some quick comments

devito/symbolics/extended_dtypes.py Outdated Show resolved Hide resolved
devito/symbolics/extended_dtypes.py Outdated Show resolved Hide resolved
devito/symbolics/printer.py Outdated Show resolved Hide resolved
devito/tools/dtypes_lowering.py Outdated Show resolved Hide resolved
devito/symbolics/extended_dtypes.py Outdated Show resolved Hide resolved
devito/passes/iet/languages/CXX.py Outdated Show resolved Hide resolved
devito/passes/iet/dtypes.py Outdated Show resolved Hide resolved
devito/passes/iet/definitions.py Outdated Show resolved Hide resolved
devito/passes/iet/dtypes.py Outdated Show resolved Hide resolved
@georgebisbas georgebisbas added API api (symbolics, types, ...) compiler labels Jul 11, 2024
devito/data/allocators.py Outdated Show resolved Hide resolved
devito/ir/iet/nodes.py Outdated Show resolved Hide resolved
devito/ir/iet/visitors.py Show resolved Hide resolved
devito/passes/iet/dtypes.py Outdated Show resolved Hide resolved
devito/symbolics/extended_dtypes.py Outdated Show resolved Hide resolved
devito/symbolics/extended_dtypes.py Outdated Show resolved Hide resolved
devito/symbolics/extended_dtypes.py Show resolved Hide resolved
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@enwask enwask force-pushed the complex branch 2 times, most recently from 59f59b8 to 493c1e8 Compare July 26, 2024 14:28
params_mapper = {}

# Lower scalar float16s to pointers and dereference them
for s in FindSymbols('scalars').visit(iet):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you just do visit(iet.parameters) directly instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like not, since the actual types I'm mapping (e.g. Constant) aren't nodes they are only caught by FindSymbols if it's as a reference within some other expression, so we need to visit the body. That said I can probably make this marginally more efficient by making a set of parameters beforehand and checking membership that way

@classmethod
def _load_dtype_mappings(cls, **kwargs):
lang: type[LangBB] = cls._Target.DataManager.lang
ctypes_vector_mapper.update(lang.mapper.get('types', {}))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that's a bit tricky because this updates a global ctypes_vector_mapper which might lead to odd behavior building multiple operators with different languages.
Do you know where it's called and needs those types ? I.e can the mapper be "local" to the operator and passed there?

devito/ir/iet/nodes.py Outdated Show resolved Hide resolved
devito/passes/iet/dtypes.py Outdated Show resolved Hide resolved
devito/passes/iet/dtypes.py Outdated Show resolved Hide resolved
devito/passes/iet/dtypes.py Outdated Show resolved Hide resolved
@@ -15,6 +18,62 @@
}


class NoDeclStruct(ctypes.Structure):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZoeLeibowitz I may be wrong but don't you need this type too? please see also where it's used and comment

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have different use cases because I need a local composite type which still generates a struct in the header file (i.e still produces a struct definition)

devito/symbolics/extended_dtypes.py Outdated Show resolved Hide resolved


def lower_complex(iet, lang, compiler):
def lower_dtypes(iet, lang, compiler, sregistry):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how can this not be an @iet_pass and still work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's called from an @iet_pass in definitions here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah OK , I maybe have told Mathias already, but imho that thing goes straight into operator/operator.py::Operator::_lower_iet

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay, I can make this change

Copy link
Contributor

@EdCaunt EdCaunt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few comments, looking good, but still needs tests

devito/ir/iet/visitors.py Outdated Show resolved Hide resolved
devito/ir/iet/visitors.py Outdated Show resolved Hide resolved
devito/ir/iet/visitors.py Outdated Show resolved Hide resolved
devito/symbolics/printer.py Outdated Show resolved Hide resolved
@FabioLuporini
Copy link
Contributor

but still needs tests

seconded, this is crucial

@enwask
Copy link
Contributor Author

enwask commented Jul 29, 2024

Yeah my plan was to write some more extensive tests for complex (since the existing ones didn't catch all of the break cases I've found) and add in a few for half. I want to get all of these changes out of the way first, and then there's a few more things to work out such as math functions being assigned to float16 arrays. Maybe this should be a draft actually?

Copy link
Contributor

@EdCaunt EdCaunt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests looking good. I think there should also be a test to ensure that the use of fp16/complex gets propagated into the variables introduced by CSE

@@ -644,16 +644,18 @@ def test_tensor(self, func1):
def test_complex(self, dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The introduction of an override is a good idea here.

I think this wants separating out into a couple of tests:

  • Test that the correct expressions are generated
  • Test that you can specify derivatives of complex fields
  • Test that you can override complex values
  • Test that you can mix complex and real symbols on the RHS

Note: do we also need to test that halo exchanges work with complex?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, will add some tests for those points.

Re: halo exchanges, I'll get back to you when I figure out what a halo exchange is

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A halo exchange is where each MPI rank swaps the data along its edges with its neighbours before advancing to the next timestep. It would probably worth checking that the MPI communication doesn't garble the data somehow when moving it between ranks. You would probably need some equations like:

from devito import Grid, TimeFunction, Eq, Operator

grid = Grid(shape=(4, 4))
x, y = grid.dimensions
t = grid.stepping_dim

u = TimeFunction(name='u', grid=grid)
v = TimeFunction(name='v', grid=grid)

eqns = [Eq(u[t+1, x, y], 1+sympy.I), Eq(v[t+1, x, y], sympy.I*u[t, x, y+1])]

op = Operator(eqns)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. I think what I'll do is leave these general complex arithmetic tests where they are (in test_operator and test_gpu_common) and add these lower-level ones to the test_dtypes module. Except this halo exchange one goes in test_mpi I guess

devito/ir/iet/visitors.py Show resolved Hide resolved
tests/test_dtypes.py Outdated Show resolved Hide resolved
tests/test_dtypes.py Outdated Show resolved Hide resolved
@enwask enwask requested a review from EdCaunt July 31, 2024 15:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API api (symbolics, types, ...) compiler
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants